import numpy as np
from itertools import combinations
from typing import Dict, Iterable, Tuple, Union, Optional

from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns


def _stack_century_embeddings(
    century_embeddings: Dict[Union[int, str], np.ndarray],
) -> Tuple[np.ndarray, np.ndarray]:
    """
    convert the century -> embedding mapping into aligned, chronologically sorted arrays.

    parameters
    ----------
    century_embeddings:
        mapping from century label (number or numeric string) to a 1d embedding.

    returns
    -------
    positions:
        1d float array with the century labels in ascending numeric order.
    matrix:
        2d float array whose i-th row is the embedding belonging to ``positions[i]``.

    raises
    ------
    ValueError
        if a label cannot be interpreted as a finite number, or if the embeddings
        are not non-empty numeric 1d vectors sharing one common length.
    """
    labelled = []
    for label, raw_vector in century_embeddings.items():
        # labels are coerced to numbers so that string labels such as "3" sort
        # chronologically (not lexicographically) and can be subtracted
        try:
            position = float(label)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"century label {label!r} cannot be interpreted as a number"
            ) from err
        if not np.isfinite(position):
            raise ValueError(f"century label {label!r} is not a finite number")

        try:
            vector = np.asarray(raw_vector, dtype=float)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"embedding for century {label!r} is not a numeric array"
            ) from err
        labelled.append((position, vector))

    # sort on the numeric label only; comparing the arrays themselves on ties would fail
    labelled.sort(key=lambda item: item[0])

    # validate each vector explicitly instead of relying on how np.array stacks
    # ragged input, which differs between numpy versions
    if any(vector.ndim != 1 or vector.size == 0 for _, vector in labelled):
        raise ValueError("all embeddings should be 1d arrays of equal length")
    if len({vector.shape for _, vector in labelled}) != 1:
        raise ValueError("all embeddings must have the same dimensionality")

    positions = np.array([position for position, _ in labelled], dtype=float)
    matrix = np.vstack([vector for _, vector in labelled])
    return positions, matrix


def _plot_distance_vs_similarity(
    distances: np.ndarray,
    sim_values: np.ndarray,
    correlation: float,
    p_value: float,
    field_of_human_activity: str,
    output_path: Optional[str],
) -> None:
    """
    draw the scatter plot with regression line, optionally save it, then show it.

    the figure is saved only when `output_path` is not None. style and context
    settings are applied temporarily, so global matplotlib/seaborn state is
    left as it was before the call.
    """
    # two colours taken from a colorblind-friendly palette to keep plots accessible
    scatter_color = "#332288"
    line_color = "#882255"

    with plt.style.context("ggplot"), sns.plotting_context("notebook", font_scale=1.2):
        # a dedicated figure prevents repeated calls from drawing onto the same axes
        fig, ax = plt.subplots()

        # scatter plot of temporal distance vs cosine similarity
        sns.scatterplot(x=distances, y=sim_values, color=scatter_color, s=50, ax=ax)

        # add regression line (without re-plotting scatter points)
        sns.regplot(x=distances, y=sim_values, scatter=False, color=line_color, ax=ax)

        # annotate the plot with correlation statistics (axes-relative coordinates)
        ax.text(
            0.95,
            0.55,
            f"pearson r = {correlation:.4f}\n"
            f"p-value = {p_value:.2e}",
            fontsize=14,
            bbox={"facecolor": "white", "alpha": 0.8},
            transform=ax.transAxes,
            ha="right",
            va="top",
        )

        ax.set_xlabel("temporal distance (|century₂ − century₁|)")
        ax.set_ylabel("cosine similarity")
        ax.set_title(
            f"temporal distance vs cosine similarity\n{field_of_human_activity}"
        )

        if output_path is not None:
            fig.savefig(
                output_path,
                dpi=500,
                transparent=True,
                bbox_inches="tight",
            )

        plt.show()

        # with a non-interactive backend show() leaves the figure registered;
        # close it so repeated calls do not accumulate open figures. in interactive
        # mode the figure is left alone so that a live window is not closed.
        if not plt.isinteractive():
            plt.close(fig)


def analyze_cosine_similarity(
    century_embeddings: Dict[Union[int, str], np.ndarray],
    field_of_human_activity: str,
    visualize: bool = True,
    save: bool = False,
    output_path: Optional[str] = None,
) -> Tuple[float, float]:
    """
    analyze how similar century-level embeddings are as temporal distance increases.

    this function takes a dictionary of century embeddings, computes the cosine
    similarity of every unique pair of centuries, and relates these similarities to
    temporal distance (absolute difference between the century labels). it reports the
    pearson correlation, r² and p-value on stdout, and optionally generates a scatter
    plot with a regression line.

    parameters
    ----------
    century_embeddings:
        mapping from century label (e.g. -1, 1, 2, ... or numeric strings such as "3")
        to a 1d numeric embedding representing that century.
    field_of_human_activity:
        label used in the plot title to indicate which field the embeddings correspond to.
    visualize:
        if True, create and display a scatter plot with regression line.
    save:
        if True, save the figure to disk. only has an effect when `visualize` is True.
    output_path:
        path where the figure should be saved if `save` is True. if `save` is True
        and `output_path` is None, the figure is not saved.

    returns
    -------
    correlation:
        pearson correlation coefficient between temporal distance and cosine similarity.
    p_value:
        two-sided p-value associated with the correlation.

    raises
    ------
    ValueError
        if fewer than two centuries are provided, if fewer than two unique century
        pairs are available for correlation estimation (i.e. fewer than three
        centuries), if a century label is not numeric, or if the embeddings are not
        1d vectors of one common length.
    """
    if not century_embeddings:
        raise ValueError("century_embeddings must not be empty")
    if len(century_embeddings) < 2:
        raise ValueError("at least two centuries are required to compute pairwise similarity")

    positions, matrix = _stack_century_embeddings(century_embeddings)

    # n centuries give n * (n - 1) / 2 unique pairs, so two pairs need three centuries;
    # checking here avoids computing similarities that cannot be correlated anyway
    n_centuries = positions.size
    if n_centuries < 3:
        raise ValueError("need at least two distinct pairs to compute a pearson correlation")

    # compute pairwise cosine similarities matrix between century embeddings
    similarities = cosine_similarity(matrix)

    # indices of the strict upper triangle enumerate every unique pair (i < j) once
    first, second = np.triu_indices(n_centuries, k=1)
    sim_values = similarities[first, second]
    # NOTE(review): distance is the plain numeric difference of the labels. if labels
    # follow the BC/AD convention without a century 0 (e.g. -1 followed by 1), centuries
    # adjacent across the era boundary come out 2 apart — confirm this is intended.
    distances = np.abs(positions[second] - positions[first])

    # compute pearson correlation between temporal distance and cosine similarity
    correlation, p_value = pearsonr(distances, sim_values)
    r_squared = correlation**2

    print(f"pearson r = {correlation:.4f}")
    print(f"r² (variance explained) = {r_squared:.4f}")
    print(f"p-value = {p_value:.2e}")
    # report the sign of a significant correlation; a nan p-value (constant input)
    # compares False and therefore falls through to "not statistically significant"
    if p_value < 0.05 and correlation < 0:
        print("the observed negative correlation is statistically significant")
    elif p_value < 0.05 and correlation > 0:
        print("the observed positive correlation is statistically significant")
    else:
        print("the observed correlation is not statistically significant")

    # optionally visualize the relationship with a scatter plot and regression line
    if visualize:
        _plot_distance_vs_similarity(
            distances,
            sim_values,
            correlation,
            p_value,
            field_of_human_activity,
            output_path if save else None,
        )

    return correlation, p_value